#!/usr/bin/env python3

import argparse
import os
import sys
import subprocess
import telnetlib
from platform import system

# Default flash base address (can be overridden with --flash-base).
FLASH_BASE = 0x80000000
# Size of the address window, starting at the flash base, that is treated as flash for .bin images.
FLASH_MAX_SIZE = 16 * 1024 * 1024
# Default OpenOCD flash driver name (can be overridden with --flash-driver).
FLASH_DRIVER = 'agrv'
# Image types accepted by --file-type; anything else is handled as "bin".
IMAGE_TYPES = [
  "bin",
  "ihex",
  "elf",
  "s19",
  "mem",
  "build",
]
# OpenOCD executable name; also the process name looked for when checking for a running instance.
OPENOCD_EXE = "openocd"
# Port used to reach a running OpenOCD over telnet.
OPENOCD_PORT = 4444
# Bytes awaited on the telnet connection before each command is sent.
OPENOCD_PROMPT = b"\r> "

def auto_int(x):
  """Convert an integer literal string, letting its prefix (0x, 0o, 0b or none) choose the base."""
  base = 0  # base 0 makes int() infer the radix from the literal itself
  return int(x, base)

def file_len(arg):
  """argparse converter for the two-value `file length` option.

  Successive calls alternate between the two positions: odd-numbered calls open `arg` for writing and return the
  file object, even-numbered calls parse `arg` as an integer with auto_int. The running call number is kept in
  `file_len.count`.
  """
  file_len.count += 1
  expects_file = file_len.count % 2 == 1
  if expects_file:
    return open(arg, 'w')
  return auto_int(arg)

# Number of times file_len has been called so far.
file_len.count = 0

# Command line definition. Options that name an existing image use argparse.FileType so a missing file is reported
# while parsing; the rest of the script only uses the `.name` of those file objects.
parser = argparse.ArgumentParser(description='OpenOCD wrapper')

parser.add_argument('-n', '--dry-run',        help='dry run',                                            action='store_true')
parser.add_argument('-d', '--exe-dir',        help='executable path',                                    type=str, default="")
parser.add_argument('-A', '--adapter',        help='adapter type',                                       type=str)
parser.add_argument('-s', '--speed',          help='adapter speed',                                      type=int)
parser.add_argument('-j', '--jtag',           help='use JTAG interface instead of SWD',                  action='store_true')
parser.add_argument('-I', '--device',         help='device type',                                        type=str, default='agrv2k')
parser.add_argument('-J', '--device-id',      help='device ID',                                          type=str)
parser.add_argument('-K', '--logic-config',   help='logic CONFIG',                                       type=str)
parser.add_argument('-M', '--logic-bin',      help='logic BIN',                                          type=str, default='')
parser.add_argument('-X', '--args',           help='additional OpenOCD arguments',                       action='append', default=[])
parser.add_argument('-g', '--progress',       help='progress bar stream when writing flash',             type=int, default=2, choices=[0, 1, 2])
parser.add_argument('-R', '--srst',           help='reset target by asserting srst while connecting',    action='store_true')
parser.add_argument('-i', '--info',           help='flash information',                                  action='store_true')
parser.add_argument('-t', '--halt',           help='halt instead of reset for start(1)/end(2)/both(3)',  type=int, default=0, const=3, choices=[0, 1, 2, 3], nargs='?')
parser.add_argument('-T', '--disable-telnet', help='disable auto telnet if OpenOCD is already running',  action='store_true')
parser.add_argument('-o', '--option-erase',   help='erase option bytes',                                 action='store_true')
parser.add_argument('-u', '--unlock',         help='disable read protection (with auto erase)',          action='store_true')
# --lock and --read-protect issue the same driver command; they differ only in when it is issued.
parser.add_argument('-L', '--lock',           help='enable read protection before image operations',     action='store_true')
parser.add_argument('-p', '--read-protect',   help='enable read protection after all operations',        action='store_true')
parser.add_argument('-f', '--fpga-config',    help='write FPGA configuration file',                      type=argparse.FileType('r'))
parser.add_argument('-F', '--fpga-address',   help='write FPGA configuration address',                   type=auto_int)
# file_len alternates between "open file" and "parse int" on successive values, matching the (file, length) pair.
parser.add_argument('-r', '--read-image',     help='read image file from flash with specified length',   nargs=2, metavar=('file', 'length'), type=file_len)
parser.add_argument('-E', '--erase-all',      help='erase the entire flash',                             action='store_true')
parser.add_argument('-e', '--erase-length',   help='erase flash for a given length',                     type=auto_int)
parser.add_argument('-w', '--write-image',    help='write image file to flash and verify',               type=argparse.FileType('r'))
parser.add_argument('-W', '--Write-image',    help='write image file to flash and verify without erase', type=argparse.FileType('r'))
parser.add_argument('-v', '--verify-image',   help='verify image file',                                  type=argparse.FileType('r'), nargs='?')
parser.add_argument('-a', '--address-offset', help='address (.bin) or offset for the image',             type=auto_int)
parser.add_argument('-y', '--file-type',      help='file type for the image',                            choices=IMAGE_TYPES)
parser.add_argument('-m', '--memory-op',      help='use memory operations instead of flash',             action='store_true')
parser.add_argument('-N', '--download',       help='download binaries from gen_download',                type=argparse.FileType('r'))
parser.add_argument('-C', '--Cmd',            help='prefix commands to be run',                          action='append', default=[])
parser.add_argument('-c', '--cmd',            help='additional commands to be run',                      action='append', default=[])
parser.add_argument('-l', '--log',            help='log OpenOCD with debug',                             type=str)
parser.add_argument('-D', '--debug',          help='set OpenOCD debug level',                            type=int)
parser.add_argument('-x', '--no-exit',        help='do not exit OpenOCD',                                action='store_true')
parser.add_argument('-z', '--rtt',            help='start RTT server',                                   action='store_true')
parser.add_argument('-O', '--offline-mode',   help='Offline operations for CMSIS-DAP',                   nargs='?', type=int, default=0, const=1)
parser.add_argument('--flash-driver',         help='flash driver',                                       type=str, default=FLASH_DRIVER)
parser.add_argument('--flash-base',           help='flash base',                                         type=auto_int, default=FLASH_BASE)

def get_image_info(file, args):
  """Work out where and how an image should be transferred.

  Args:
    file: image file object (only its `.name` is inspected), or a falsy value when there is no image
      (pure erase), in which case the type is "bin".
    args: parsed command line; reads address_offset, file_type, memory_op, offline_mode and flash_base.

  Returns:
    Tuple (address_offset, file_type, memory_op).
  """
  address_offset = args.address_offset
  file_type = args.file_type
  memory_op = args.memory_op

  if not file:
    file_type = "bin"
  elif not file_type:
    # No explicit --file-type: derive it from the file extension.
    file_type = os.path.splitext(file.name)[-1].lstrip(os.extsep).lower()
  if file_type not in IMAGE_TYPES or args.offline_mode:
    file_type = "bin"

  # Compare against None rather than testing truthiness so that an explicit address of 0 is honoured
  # instead of being replaced by the flash base.
  if address_offset is None:
    address_offset = args.flash_base if file_type == "bin" and not args.offline_mode else 0

  if file_type == "bin": # Can auto decide operation type for bin files
    flash_end = args.flash_base + FLASH_MAX_SIZE
    memory_op = not (args.flash_base <= address_offset < flash_end)
  return (address_offset, file_type, memory_op)

def get_image_cmds(file, read, write, erase, args, length=0):
  """Build the OpenOCD commands for one read / erase / write / verify request.

  Args:
    file: image file object (its `.name` is used), or a falsy value for a pure erase.
    read: dump `length` bytes into `file`; when set, only the read command is returned.
    write: write the image.
    erase: erase before writing, or on its own when there is no file.
    args: parsed command line; reads offline_mode and erase_length, and is passed to get_image_info.
    length: number of bytes to read (only used with `read`).

  Returns:
    List of command strings; empty when there is neither a file nor a pure erase request.
  """
  pure_erase = erase and not (read or write)
  if not (file or pure_erase):
    return []

  address_offset, file_type, memory_op = get_image_info(file, args)
  file_name = file.name if file else ""
  erase_length = os.path.getsize(file.name) if file else args.erase_length
  # A verify is issued after every write, and also when nothing is erased or written (verify-only).
  do_verify = write or not erase

  if args.offline_mode:
    if read:
      return [f"cmsis-dap-download read {{{file_name}}} 0x{address_offset:x} {length}"]
    offline_cmds = []
    if erase:
      offline_cmds.append(f"cmsis-dap-download erase 0x{address_offset:x} {erase_length}")
    if write:
      offline_cmds.append(f"cmsis-dap-download write {{{file_name}}} 0x{address_offset:x}")
    if do_verify:
      offline_cmds.append(f"cmsis-dap-download verify {{{file_name}}} 0x{address_offset:x}")
    return offline_cmds

  if read:
    return [f"dump_image {{{file_name}}} 0x{address_offset:x} {length}"]

  cmds = []
  flash_erase = erase and not memory_op
  # Separate erase with bin format for better message; other formats erase as part of the write command.
  inline_erase = "erase " if flash_erase and file_type != "bin" else ""
  if flash_erase and not inline_erase:
    cmds.append(f"flash erase_address pad 0x{address_offset:x} {erase_length}")
  if write and memory_op:
    cmds.append(f"load_image {{{file_name}}} 0x{address_offset:x} {file_type}")
  elif write:
    cmds.append(f"flash write_image {inline_erase}{{{file_name}}} 0x{address_offset:x} {file_type}")
  if do_verify:
    cmds.append(f"verify_image {{{file_name}}} 0x{address_offset:x} {file_type}")
  return cmds

def _openocd_is_running(psutil):
  """Return True if a process named OPENOCD_EXE exists; processes that cannot be queried are skipped."""
  for proc in psutil.process_iter():
    try:
      if proc.name() == OPENOCD_EXE:
        return True
    except psutil.Error:
      # The process exited or is not accessible; it cannot be the one we are looking for.
      continue
  return False

def run_telnet(args, oo_cmds):
  """Send `oo_cmds` to an already running OpenOCD over its telnet port, if there is one.

  Args:
    args: parsed command line; reads disable_telnet, halt and dry_run.
    oo_cmds: OpenOCD commands to send. The list is not modified.

  Returns:
    0 when the commands were sent (or printed for a dry run, or the user interrupted the session),
    -1 when the telnet session failed part-way,
    1 when telnet was not used and the caller should launch OpenOCD itself.
  """
  if args.disable_telnet:
    return 1
  try:
    import psutil
  except ImportError:
    # psutil is optional: without it a running OpenOCD cannot be detected.
    return 1
  if not _openocd_is_running(psutil):
    return 1

  print("OpenOCD is already running. Using telnet to control it.\n")

  # Work on a copy: the caller reuses its own list to launch a new OpenOCD when telnet is unavailable, and the
  # telnet-only "resume" must not leak into that command line.
  cmds = list(oo_cmds)
  if not args.halt:
    cmds.append("resume")

  if args.dry_run:
    print('\n'.join(cmds))
    return 0

  try:
    telnet = telnetlib.Telnet("127.0.0.1", port=OPENOCD_PORT)
  except OSError:
    print(f"Failed to telnet to OpenOCD on port {OPENOCD_PORT}!", file=sys.stderr)
    print("Trying to launch another OpenOCD instance...\n")
    return 1

  # Save the current working directory in telnet and switch to the directory running this script, since files used
  # in this script may have relative paths. Restore the telnet directory upon finish.
  # The path is brace-quoted so that spaces and backslashes reach Tcl unchanged.
  cmds = [f"variable telnet_dir [pwd]; cd {{{os.getcwd()}}}"] + cmds + ["cd $telnet_dir", ""]
  try:
    for cmd in cmds:
      output = telnet.expect([OPENOCD_PROMPT])[2]
      # Print all read text except for the last line (the prompt)
      print("\n".join(output.decode().split("\n")[:-1]))
      telnet.write(str.encode(cmd + "\n"))
  except KeyboardInterrupt:
    print("telnet stopped by user\n")
    return 0
  except Exception:
    print("telnet connection closed\n")
    return -1
  finally:
    # Release the socket on every path, including the error ones.
    telnet.close()

  return 0

def main():
  """Translate the command line into OpenOCD commands and run them.

  The commands are first offered to an already running OpenOCD through run_telnet. If telnet is not used, a new
  OpenOCD process is started with every command passed as a `-c` argument.

  Returns:
    run_telnet's result when it is 0 or negative; otherwise OpenOCD's return code, 0 for a dry run or when the
    user interrupts OpenOCD, and -1 when running OpenOCD raised an exception.
  """
  args = parser.parse_args(sys.argv[1:])
  adapter_speed = args.speed if args.speed else 12000 if args.jtag else 15000

  # OpenOCD command line: executable (OPENOCD_EXE environment variable wins), prefix commands, this script's
  # directory as a search path, then Tcl variables (presumably consumed by the device .cfg) and the cfg itself.
  oo_args = [os.environ.get('OPENOCD_EXE') or os.path.join(args.exe_dir, OPENOCD_EXE)]
  for cmd in args.Cmd:
    oo_args += ["-c", cmd]
  oo_args += ["-s", os.path.dirname(sys.argv[0])]
  oo_args += ["-c", f"variable ADAPTER_SPEED {adapter_speed}"]
  oo_args += ["-c", f"variable ADAPTER {args.adapter}"] if args.adapter else []
  oo_args += ["-c", "variable TRANSPORT jtag"] if args.jtag else []
  oo_args += ["-c", f"variable AGRV_PROGRESS {args.progress}"]
  oo_args += ["-c", "variable CONNECT_UNDER_RESET 1"] if args.srst else []
  oo_args += ["-c", "variable ADAPTER_OFFLINE 1"] if args.offline_mode else []
  oo_args += ["-f", f"{args.device}.cfg"]
  oo_args += args.args

  # Commands that only prepare or check the target; they do not count as operations (see no_exit below).
  oo_cmds = []
  oo_cmds += [] if args.offline_mode else ["halt"] if args.halt & 0x1 else ["reset init"]
  if args.device_id:
    check_device_id = f"check_device_id {args.device_id}"
    oo_cmds += ["if {[catch {%s} msg]} { error \"Error: $msg.\" }"%check_device_id]
  if args.logic_config:
    check_logic = f"check_logic {args.logic_config} {{{args.logic_bin}}}"
    oo_cmds += ["if {[catch {%s} msg]} { error \"Error: $msg.\" }"%check_logic]

  # Operations start here.
  non_op_cmds = len(oo_cmds)
  if args.info:
    if args.offline_mode:
      oo_cmds += ["cmsis-dap-download info"]
    else:
      oo_cmds += [f"echo -n [flash info 0]; echo -n [{args.flash_driver} options_read 0]"]
  oo_cmds += [f"{args.flash_driver} options_erase 0; reset {'halt' if args.halt else 'init'}"] if args.option_erase else []
  oo_cmds += [f"{args.flash_driver} unlock 0; reset {'halt' if args.halt else 'init'}; flash probe 0"] if args.unlock else []
  oo_cmds += [f"{args.flash_driver} lock 0"] if args.lock else []
  oo_cmds += get_image_cmds(args.read_image[0] if args.read_image else None, True, False, False, args, args.read_image[1] if args.read_image else 0)
  oo_cmds += (["cmsis-dap-download erase"] if args.offline_mode else [f"{args.flash_driver} mass_erase 0"]) if args.erase_all else []
  oo_cmds += get_image_cmds(None, False, False, bool(args.erase_length), args)
  if args.fpga_config:
    fpga_config_name = args.fpga_config.name
    oo_cmds += [f"{args.flash_driver} write_fpga_config {{{fpga_config_name}}}{f' 0x{args.fpga_address:x}' if args.fpga_address else ''}"]
  elif args.fpga_address:
    oo_cmds += [f"{args.flash_driver} options_write FPGA 0x{args.fpga_address:x}"]
  oo_cmds += get_image_cmds(args.write_image,  False, True,  True,  args)
  oo_cmds += get_image_cmds(args.Write_image,  False, True,  False, args)
  oo_cmds += get_image_cmds(args.verify_image, False, False, False, args)
  oo_cmds += [f"{args.flash_driver} lock 0"] if args.read_protect else [] # Do read protect at the end of all operations
  oo_cmds += [f"cmsis-dap-download run 0x{args.address_offset or 0:x} {args.offline_mode * 1000}", "shutdown"] if args.offline_mode >= 2 else []
  # A bare `-v` parses to None just like an absent option, so its presence is detected from the raw argv.
  oo_cmds += [f"{args.flash_driver} download {'verify ' if '-v' in sys.argv else ''}{{{args.download.name}}}"] if args.download else []

  # Keep OpenOCD running when asked to, when RTT is served, or when no operation or extra command was requested.
  no_exit = args.no_exit or args.rtt or (len(oo_cmds) == non_op_cmds and len(args.cmd) == 0)
  oo_cmds += args.cmd
  oo_cmds += [] if no_exit or args.offline_mode or args.halt & 0x2 else ["reset halt"]

  # Hand over a copy so that nothing added for the telnet session can end up in the command line built below.
  telnet_ret = run_telnet(args, list(oo_cmds))
  if telnet_ret <= 0:
    return telnet_ret

  for cmd in oo_cmds:
    oo_args += ["-c", cmd]

  oo_args += ["-c", "rtt_start"] if args.rtt else []
  # --log turns debug output on at OpenOCD's default level unless --debug gives one; --debug is also honoured
  # on its own, and a level of 0 is a valid request (hence the `is not None` tests).
  debug_flag = "-d" + (f"{args.debug}" if args.debug is not None else "")
  if args.log:
    oo_args += [debug_flag, "-l", args.log]
  elif args.debug is not None:
    oo_args += [debug_flag]
  oo_args += [] if no_exit else ["-c", "shutdown"]

  if args.dry_run:
    print(' '.join([f"\"{arg}\"" for arg in oo_args]))
    return 0

  try:
    return subprocess.run(oo_args).returncode
  except KeyboardInterrupt:
    return 0
  except Exception as e:
    print(str(e), file=sys.stderr)
    return -1

if __name__ == "__main__":
  # Use main()'s return value as the process exit status.
  raise SystemExit(main())
